Code
# Imports ------------------------------------------------
from BI import bi
import jax.numpy as jnp
# Setup device------------------------------------------------
# Build the BI handle on the CPU platform; every sampler and network helper
# used below is reached through `m`.
m = bi(platform="cpu", rand_seed=False)

# Simulate data ------------------------------------------------
N = 50  # number of simulated individuals

# Individual-level covariate drawn from normal(0, 1), kept as an (N, 1) column.
individual_predictor = m.dist.normal(0, 1, shape=(N, 1), sample=True)

# Binary N x N kinship matrix drawn from bernoulli(0.3); the diagonal is
# zeroed so that no individual is marked as kin of itself.
kinship = m.dist.bernoulli(0.3, shape=(N, N), sample=True)
kinship = kinship.at[jnp.diag_indices(N)].set(0)

# One category per individual, drawn from four equally weighted classes.
category_probs = jnp.array([.25, .25, .25, .25])
category = m.dist.categorical(category_probs, sample=True, shape=(N,))

# Distinct categories actually drawn, and how many individuals fall in each.
observed_groups, N_by_grp = jnp.unique(category, return_counts=True)
N_grp = observed_groups.shape[0]
def sim_network(kinship, individual_predictor, category):
    """Simulate a binary network from block, sender-receiver and dyadic effects.

    Args:
        kinship: Dyadic predictor forwarded to ``m.net.dyadic_effect``. The
            caller in this script passes the edge-list form returned by
            ``m.net.mat_to_edgl``.
        individual_predictor: Individual-level covariate with one row per
            individual; used as both the sender and the receiver predictor.
        category: Group label of each individual, forwarded to
            ``m.net.block_model``.

    Returns:
        The draw returned by ``m.dist.bernoulli`` whose logits are the sum of
        the four simulated effects.

    Note:
        ``m``, ``N_grp`` and ``N_by_grp`` are read from module scope.
    """
    # Take the number of individuals from the input (the same way `model`
    # computes N_id) rather than from the module-level N, so the function
    # stays correct for predictors of any length.
    n_id = individual_predictor.shape[0]

    # Intercept: a single block that contains every individual.
    B_intercept = m.net.block_model(
        jnp.full((n_id,), 0), 1, jnp.array([n_id]), sample=True
    )
    # Block effect between the sampled categories.
    B_category = m.net.block_model(category, N_grp, N_by_grp, sample=True)

    # SR: sender and receiver effects driven by the same covariate.
    sr = m.net.sender_receiver(
        individual_predictor,
        individual_predictor,
        s_mu=0.4,
        r_mu=-0.4,
        sample=True,
    )

    # D: dyadic effect of kinship.
    DR = m.net.dyadic_effect(kinship, d_sd=2.5, sample=True)

    return m.dist.bernoulli(
        logits=B_intercept + B_category + sr + DR,
        sample=True,
    )
# Simulate the observed network from the edge-list form of the kinship matrix.
network = sim_network(
    m.net.mat_to_edgl(kinship),
    individual_predictor,
    category,
)

# Predictive model ------------------------------------------------
# Data for the model. The keys mirror the parameter names of `model`
# (presumably how `m.fit` binds them -- confirm against the BI docs).
# The same covariate serves as focal and target predictor.
m.data_on_model = {
    "network": network,
    "dyadic_predictors": m.net.mat_to_edgl(kinship),
    "focal_individual_predictors": individual_predictor,
    "target_individual_predictors": individual_predictor,
    "category": category,
}
def model(network, dyadic_predictors, focal_individual_predictors,
          target_individual_predictors, category):
    """Bernoulli network model with block, sender-receiver and dyadic terms.

    Args:
        network: Observed ties, passed as ``obs`` to the Bernoulli likelihood.
        dyadic_predictors: Dyadic covariate forwarded to
            ``m.net.dyadic_effect``.
        focal_individual_predictors: Covariate of the focal individuals; its
            first dimension gives the number of individuals.
        target_individual_predictors: Covariate of the target individuals.
        category: Group label of each individual, forwarded to
            ``m.net.block_model``.

    Note:
        ``m``, ``N_grp`` and ``N_by_grp`` are read from module scope. The
        function returns nothing; it only declares the model terms and the
        likelihood through ``m``.
    """
    n_individuals = focal_individual_predictors.shape[0]

    # Block ---------------------------------------
    # Intercept expressed as a single block holding every individual.
    single_block = jnp.full((n_individuals,), 0)
    single_block_size = jnp.array([n_individuals])
    intercept_effect = m.net.block_model(
        single_block, 1, single_block_size, name="B_intercept"
    )
    category_effect = m.net.block_model(
        category, N_grp, N_by_grp, name="B_category"
    )

    ## SR shape = N individuals---------------------------------------
    sr = m.net.sender_receiver(
        focal_individual_predictors,
        target_individual_predictors,
        s_mu=0.4,
        r_mu=-0.4,
    )

    # Dyadic shape = N dyads--------------------------------------
    dr = m.net.dyadic_effect(dyadic_predictors, d_sd=2.5)

    # Likelihood: the observed network given the summed effects as logits.
    logits = intercept_effect + category_effect + sr + dr
    m.dist.bernoulli(logits=logits, obs=network)
# Fit ------------------------------------------------
# Fit `model` through the BI handle with the progress bar disabled.
m.fit(model, progress_bar=False)
m.summary()
/home/sosa/work/3.12venv/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning:
IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
BI v 0.0.46 package loaded
jax.local_device_count 32
/home/sosa/work/BI/BI/Diagnostic/jax_diagnostics.py:214: RuntimeWarning:
invalid value encountered in scalar divide
| | mean | sd | hdi_5.5% | hdi_94.5% | mcse_mean | mcse_sd | ess_bulk | ess_tail | r_hat |
|---|---|---|---|---|---|---|---|---|---|
| b_B_category[0, 0] | -5.00 | 1.48 | -7.31 | -2.63 | 0.03 | 0.02 | 1874.61 | 2858.80 | 1.00 |
| b_B_category[0, 1] | -7.51 | 1.65 | -10.23 | -5.00 | 0.03 | 0.02 | 2509.65 | 3070.09 | 1.00 |
| b_B_category[0, 2] | -3.53 | 1.30 | -5.48 | -1.41 | 0.03 | 0.02 | 1640.90 | 2364.57 | 1.00 |
| b_B_category[0, 3] | -2.17 | 1.50 | -4.52 | 0.24 | 0.03 | 0.02 | 2173.03 | 2867.06 | 1.00 |
| b_B_category[1, 0] | -6.16 | 1.37 | -8.34 | -4.01 | 0.03 | 0.02 | 1679.45 | 2554.26 | 1.00 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| sr_rf[48, 1] | 4.07 | 1.42 | 1.78 | 6.24 | 0.04 | 0.03 | 1035.00 | 1828.40 | 1.00 |
| sr_rf[49, 0] | 1.88 | 0.96 | 0.40 | 3.43 | 0.03 | 0.02 | 1409.50 | 1947.23 | 1.00 |
| sr_rf[49, 1] | 0.79 | 1.61 | -1.85 | 3.25 | 0.04 | 0.03 | 2032.87 | 2557.14 | 1.00 |
| sr_sigma[0] | 1.12 | 0.30 | 0.62 | 1.56 | 0.01 | 0.01 | 580.39 | 886.83 | 1.01 |
| sr_sigma[1] | 4.83 | 0.86 | 3.45 | 6.11 | 0.03 | 0.02 | 608.53 | 1287.40 | 1.01 |
5131 rows × 9 columns